Skip to content

skip h2d and d2h copies between forward functions in gemma4-31b#20286

Open
Gasoonjia wants to merge 3 commits into
mainfrom
export-D108661628
Open

skip h2d and d2h copies between forward functions in gemma4-31b#20286
Gasoonjia wants to merge 3 commits into
mainfrom
export-D108661628

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Summary: This diff updates gemma4-31b export and runtime pipeline to skip the h2d and d2h copies between prefill and decode, and between previous round next round of decode as well.

  MAIN Current
total H2D (d=128) 388 / 417 µs 5 / 7.4 µs
per-round H2D 3 0
decode tok/s 45.86 47.02

Differential Revision: D108661628

@pytorch-bot

pytorch-bot Bot commented Jun 15, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20286

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 35 Pending, 1 Unrelated Failure, 2 Unclassified Failures

As of commit ef29abd with merge base 48ff29e (image):

NEW FAILURE - The following job has failed:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 15, 2026
@meta-codesync

meta-codesync Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

@Gasoonjia has exported this pull request. If you are a Meta employee, you can view the originating Diff in D108661628.

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia changed the title skip h2d and d2h copies methods skip h2d and d2h copies between forward functions in gemma4-31b Jun 15, 2026
Gasoonjia and others added 3 commits June 17, 2026 00:59
Summary: This diff updates gemma4-31b export and runtime pipeline to skip the h2d and d2h copies between prefill and decode, and between previous round next round of decode as well.

Differential Revision: D108661628
…ias)

Eliminate the per-decode-round token D2H->H2D round-trip:
- sampler.py: emit the sampled token as int64 (was float32) so the decode
  method's int64 token output can be aliased directly as the next forward's
  int64 token input (value-preserving: argmax index; token ids < 2^24).
- main.cpp: read_token reads int64; each forward's on-device output token is
  aliased via make_tensor_ptr and fed straight back as the next step's token
  input (prefill->decode and decode->decode). Only the per-round position
  H2D remains.

Measured (int6/gguf, cuda graph OFF, p19/d128): post-load HtoD 261->132
(token H2D removed; ~= decode length); DtoH/DtoD counts unchanged (129),
bytes 4B->8B (token now int64). Greedy output byte-identical to prior export.
Kill the per-decode-round position H2D (the last per-round host->device copy
left after Option A): upload the full decode position array to device once
(single H2D), then each step copy that step's position from the array into the
fixed position input slot with an on-device D2D. Token stays aliased on device
(Option A). Per-round HtoD is now 0, independent of decode length; the fixed
input slot keeps it cuda-graph-safe (with cuda graph on, the D2D becomes a
captured cudaMemcpyAsync on the decode stream into the same slot).

Measured (int6/gguf, cuda graph OFF, p19/d128): post-load HtoD 132->5
(per-round H2D=0); DtoD 129->257 (+128 per-round pos d2d, the intended
H2D->d2d trade); DtoH unchanged (129). Greedy output byte-identical to prior
runs. Runner-only; reuses the int64-output export (no re-export).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant